#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import itertools
import os

from gen_elemwise_multi_type_utils import (  # isort: skip; isort: skip
    MODES,
    QINT32_MODES,
    SUPPORT_DTYPES,
    SUPPORT_QINT32_DTYPES,
)


def generate(modes, support_dtypes, output, cpp_ext):
    """Write one kernel-instantiation source file per (mode, dtype pair).

    For every key of ``modes`` combined with every entry of
    ``support_dtypes``, and for each mode listed under that key, a file named
    ``{mode}_{src_ctype}_{dst_ctype}.{cpp_ext}`` is written into ``output``.
    Each file only defines the KERN_IMPL_* macros and then includes
    ``../kern_impl.inl``, which holds the shared implementation (not visible
    here).

    Progress is printed to stdout: one line per (key, dtype entry) pair and
    one line per generated file.

    Args:
        modes: mapping whose keys are written out as ``KERN_IMPL_ARITY`` and
            whose values are iterables of mode names.
        support_dtypes: iterable of dtype entries; only the first two items
            of each entry are used, as source and destination ctype.
        output: directory to write into; it must already exist.
        cpp_ext: file extension without the leading dot, e.g. ``"cu"``.
    """
    for anum, ctype in itertools.product(modes.keys(), support_dtypes):
        print("{} : {}".format(anum, ctype))
        src_ctype, dst_ctype = ctype[0], ctype[1]
        for mode in modes[anum]:
            formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
            fname = os.path.join(
                output,
                "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext),
            )
            lines = [
                "// generated by gen_elemwise_multi_type_kern_impls.py",
                "#define KERN_IMPL_MODE(cb) {}".format(formode),
                "#define KERN_IMPL_ARITY {}".format(anum),
                "#define KERN_IMPL_STYPE {}".format(src_ctype),
                "#define KERN_IMPL_DTYPE {}".format(dst_ctype),
                '#include "../kern_impl.inl"',
            ]
            # Explicit encoding and newline make the generated files
            # byte-identical regardless of the host's locale or platform
            # line-ending default.
            with open(fname, "w", encoding="utf-8", newline="\n") as fout:
                fout.write("\n".join(lines) + "\n")

            print("generated {}".format(fname))


def main():
    """Parse command-line arguments and generate all kernel impl files.

    The positional ``output`` argument is the directory to write into; it is
    created (including parents) if missing. ``--type`` selects the backend,
    and only ``cuda`` is accepted. Files are generated for both the regular
    mode/dtype tables and the QINT32 tables, then the output directory's
    timestamp is refreshed.
    """
    # Backend accepted by --type -> extension of the generated source files.
    # Deriving ``choices`` from this table keeps the two in sync and makes
    # the lookup below total, so no assert or possibly-unbound local is needed.
    ext_by_type = {"cuda": "cu"}

    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=sorted(ext_by_type),
        default="cuda",
        help="generate cuda kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race; an existing non-directory
    # at this path still raises FileExistsError.
    os.makedirs(args.output, exist_ok=True)

    cpp_ext = ext_by_type[args.type]

    generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
    generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
    # Set the directory's mtime to now; presumably so a build rule keyed on
    # the directory sees it as freshly generated — confirm with the build.
    os.utime(args.output)


if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
